--- title: fastai unet + feature_loss keywords: fastai sidebar: home_sidebar ---
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# Make the repo root importable so the local `superres` package resolves.
import sys
sys.path.append('..')

from superres.datasets import *
from superres.databunch import *

# Explicit imports so the seeding below does not depend on what the star
# imports happen to re-export (fastai-style `import *` usually provides
# `random` and `np`, but being explicit is safe either way).
import random
import numpy as np

# Fix RNG state so the train/valid split and augmentations are reproducible.
seed = 8610
random.seed(seed)
np.random.seed(seed)
gram_matrix[source]
gram_matrix(x)
class FeatureLoss[source]
FeatureLoss(m_feat, layer_ids, layer_wgts) :: Module
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Toy two-layer CNN illustrating submodule registration on nn.Module."""

    def __init__(self):
        super().__init__()
        # Assigning Modules as attributes registers them (and their params).
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # Two conv + ReLU stages applied in sequence.
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.
# Training data: DIV2K HR images pre-cropped to 256x256 (project-provided path).
train_hr = div2k_train_hr_crop_256
# Input/output resolution fed to the network; with in_size == out_size the LR
# input is presumably resized back to 256 before the model — TODO confirm
# against create_sr_databunch.
in_size = 256
out_size = 256
# Downscaling factor used to synthesize the low-resolution inputs.
scale = 4
# Batch size.
bs = 10
# Build the super-resolution DataBunch; `seed` keeps the split reproducible.
data = create_sr_databunch(train_hr, in_size=in_size, out_size=out_size, scale=scale, bs=bs, seed=seed)
print(data)
data.show_batch()
ImageDataBunch; Train: LabelList (25245 items) x: ImageImageList Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256) y: ImageImageList Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256) Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256; Valid: LabelList (6311 items) x: ImageImageList Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256) y: ImageImageList Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256) Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256; Test: None
# Encoder backbone for the U-Net (resnet34; fastai pretrains it by default).
model = models.resnet34
# Perceptual loss (pixel + VGG feature + gram terms) — project helper; see
# the FeatureLoss class documented above.
loss_func = create_feature_loss()
metrics = [m_psnr, m_ssim]  # image-quality metrics: PSNR and SSIM
# LossMetrics reports each FeatureLoss component (pixel/feat_*/gram_*) in the
# training table. NOTE(review): a single callback class, not a list — fastai
# accepts either form.
callback_fns = LossMetrics
wd = 1e-3  # weight decay
# Constrain model outputs via a scaled sigmoid to roughly the range of
# normalized image tensors.
y_range = (-3.,3.)
model_name = 'feat_loss'
# DynamicUnet with blur upsampling, weight norm, and self-attention.
learn = unet_learner(data, model, wd=wd, metrics=metrics, y_range=y_range,
loss_func=loss_func, callback_fns=callback_fns,
blur=True, norm_type=NormType.Weight, self_attention=True)
learn.path = Path('.')  # save checkpoints next to the notebook
# Sweep learning rates to pick a peak lr for fit_one_cycle below.
lr_find(learn)
learn.recorder.plot(suggestion=True)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 2.29E-04
Min loss divided by 10: 6.31E-04
# Peak learning rate for the one-cycle schedule (chosen from the lr_find plot).
lr = 1e-3
lrs = slice(lr)  # single lr applied across all layer groups
epoch = 3  # number of epochs
pct_start = 0.3  # fraction of the cycle spent ramping the lr up
wd = 1e-3  # weight decay
save_fname = model_name
# Plot losses live and checkpoint the best model (by valid_loss) each epoch.
callbacks = [ShowGraph(learn), SaveModelCallback(learn, name=save_fname)]
learn.fit_one_cycle(epoch, lrs, pct_start=pct_start, wd=wd, callbacks=callbacks)
| epoch | train_loss | valid_loss | m_psnr | m_ssim | pixel | feat_0 | feat_1 | feat_2 | gram_0 | gram_1 | gram_2 | time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.607700 | 1.489443 | 31.873768 | 0.411500 | 0.145787 | 0.194045 | 0.212579 | 0.065989 | 0.365943 | 0.422427 | 0.082674 | 13:55 |
| 1 | 1.483280 | 1.377554 | 32.458805 | 0.415973 | 0.137284 | 0.188241 | 0.202604 | 0.063084 | 0.321470 | 0.385370 | 0.079502 | 13:53 |
| 2 | 1.435708 | 1.312828 | 33.920448 | 0.426187 | 0.128036 | 0.184923 | 0.198252 | 0.061971 | 0.298555 | 0.362909 | 0.078183 | 13:53 |
Better model found at epoch 0 with valid_loss value: 1.489443302154541. Better model found at epoch 1 with valid_loss value: 1.377553939819336. Better model found at epoch 2 with valid_loss value: 1.312827706336975.
learn.show_results()
# Evaluate on Set14: LR inputs (downscaled then bicubic-upsampled back to
# in_size) against the original HR targets.
test_hr = set14_hr
# Reuse the training-time `scale` variable instead of a hard-coded 4 so the
# evaluation stays consistent if the training scale is changed above.
il_test_x = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=in_size, scale=scale, sizeup=True))
il_test_y = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=out_size))
# Restore the best checkpoint saved by SaveModelCallback before scoring.
_ = learn.load(save_fname)
# Report PSNR/SSIM for the model vs. the bicubic baseline.
sr_test(learn, il_test_x, il_test_y, model_name)
bicubic: PSNR:24.11,SSIM:0.7822 feat_loss: PSNR:25.13,SSIM:0.8130
model
<function torchvision.models.resnet.resnet34(pretrained=False, progress=True, **kwargs)>
learn.summary()
DynamicUnet ====================================================================== Layer (type) Output Shape Param # Trainable ====================================================================== Conv2d [64, 128, 128] 9,408 False ______________________________________________________________________ BatchNorm2d [64, 128, 128] 128 True ______________________________________________________________________ ReLU [64, 128, 128] 0 False ______________________________________________________________________ MaxPool2d [64, 64, 64] 0 False ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ ReLU [64, 64, 64] 0 False ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ ReLU [64, 64, 64] 0 False ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ ReLU [64, 64, 64] 0 False ______________________________________________________________________ Conv2d [64, 64, 64] 36,864 False 
______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ Conv2d [128, 32, 32] 73,728 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ ReLU [128, 32, 32] 0 False ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ Conv2d [128, 32, 32] 8,192 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ ReLU [128, 32, 32] 0 False ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ ReLU [128, 32, 32] 0 False ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False 
______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ ReLU [128, 32, 32] 0 False ______________________________________________________________________ Conv2d [128, 32, 32] 147,456 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True ______________________________________________________________________ Conv2d [256, 16, 16] 294,912 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 32,768 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False 
______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ ReLU [256, 16, 16] 0 False ______________________________________________________________________ Conv2d [256, 16, 16] 589,824 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [512, 8, 8] 1,179,648 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True 
______________________________________________________________________ ReLU [512, 8, 8] 0 False ______________________________________________________________________ Conv2d [512, 8, 8] 2,359,296 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ Conv2d [512, 8, 8] 131,072 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ Conv2d [512, 8, 8] 2,359,296 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ ReLU [512, 8, 8] 0 False ______________________________________________________________________ Conv2d [512, 8, 8] 2,359,296 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ Conv2d [512, 8, 8] 2,359,296 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ ReLU [512, 8, 8] 0 False ______________________________________________________________________ Conv2d [512, 8, 8] 2,359,296 False ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ BatchNorm2d [512, 8, 8] 1,024 True ______________________________________________________________________ ReLU [512, 8, 8] 0 False ______________________________________________________________________ Conv2d [1024, 8, 8] 4,719,616 True ______________________________________________________________________ ReLU [1024, 8, 8] 0 False 
______________________________________________________________________ Conv2d [512, 8, 8] 4,719,104 True ______________________________________________________________________ ReLU [512, 8, 8] 0 False ______________________________________________________________________ Conv2d [1024, 8, 8] 525,312 True ______________________________________________________________________ PixelShuffle [256, 16, 16] 0 False ______________________________________________________________________ ReplicationPad2d [256, 17, 17] 0 False ______________________________________________________________________ AvgPool2d [256, 16, 16] 0 False ______________________________________________________________________ ReLU [1024, 8, 8] 0 False ______________________________________________________________________ BatchNorm2d [256, 16, 16] 512 True ______________________________________________________________________ Conv2d [512, 16, 16] 2,359,808 True ______________________________________________________________________ ReLU [512, 16, 16] 0 False ______________________________________________________________________ Conv2d [512, 16, 16] 2,359,808 True ______________________________________________________________________ ReLU [512, 16, 16] 0 False ______________________________________________________________________ ReLU [512, 16, 16] 0 False ______________________________________________________________________ Conv2d [1024, 16, 16] 525,312 True ______________________________________________________________________ PixelShuffle [256, 32, 32] 0 False ______________________________________________________________________ ReplicationPad2d [256, 33, 33] 0 False ______________________________________________________________________ AvgPool2d [256, 32, 32] 0 False ______________________________________________________________________ ReLU [1024, 16, 16] 0 False ______________________________________________________________________ BatchNorm2d [128, 32, 32] 256 True 
______________________________________________________________________ Conv2d [384, 32, 32] 1,327,488 True ______________________________________________________________________ ReLU [384, 32, 32] 0 False ______________________________________________________________________ Conv2d [384, 32, 32] 1,327,488 True ______________________________________________________________________ ReLU [384, 32, 32] 0 False ______________________________________________________________________ Conv1d [48, 1024] 18,432 True ______________________________________________________________________ Conv1d [48, 1024] 18,432 True ______________________________________________________________________ Conv1d [384, 1024] 147,456 True ______________________________________________________________________ ReLU [384, 32, 32] 0 False ______________________________________________________________________ Conv2d [768, 32, 32] 295,680 True ______________________________________________________________________ PixelShuffle [192, 64, 64] 0 False ______________________________________________________________________ ReplicationPad2d [192, 65, 65] 0 False ______________________________________________________________________ AvgPool2d [192, 64, 64] 0 False ______________________________________________________________________ ReLU [768, 32, 32] 0 False ______________________________________________________________________ BatchNorm2d [64, 64, 64] 128 True ______________________________________________________________________ Conv2d [256, 64, 64] 590,080 True ______________________________________________________________________ ReLU [256, 64, 64] 0 False ______________________________________________________________________ Conv2d [256, 64, 64] 590,080 True ______________________________________________________________________ ReLU [256, 64, 64] 0 False ______________________________________________________________________ ReLU [256, 64, 64] 0 False 
______________________________________________________________________ Conv2d [512, 64, 64] 131,584 True ______________________________________________________________________ PixelShuffle [128, 128, 128] 0 False ______________________________________________________________________ ReplicationPad2d [128, 129, 129] 0 False ______________________________________________________________________ AvgPool2d [128, 128, 128] 0 False ______________________________________________________________________ ReLU [512, 64, 64] 0 False ______________________________________________________________________ BatchNorm2d [64, 128, 128] 128 True ______________________________________________________________________ Conv2d [96, 128, 128] 165,984 True ______________________________________________________________________ ReLU [96, 128, 128] 0 False ______________________________________________________________________ Conv2d [96, 128, 128] 83,040 True ______________________________________________________________________ ReLU [96, 128, 128] 0 False ______________________________________________________________________ ReLU [192, 128, 128] 0 False ______________________________________________________________________ Conv2d [384, 128, 128] 37,248 True ______________________________________________________________________ PixelShuffle [96, 256, 256] 0 False ______________________________________________________________________ ReLU [384, 128, 128] 0 False ______________________________________________________________________ MergeLayer [99, 256, 256] 0 False ______________________________________________________________________ Conv2d [99, 256, 256] 88,308 True ______________________________________________________________________ ReLU [99, 256, 256] 0 False ______________________________________________________________________ Conv2d [99, 256, 256] 88,308 True ______________________________________________________________________ ReLU [99, 256, 256] 0 False 
______________________________________________________________________ MergeLayer [99, 256, 256] 0 False ______________________________________________________________________ Conv2d [3, 256, 256] 300 True ______________________________________________________________________ SigmoidRange [3, 256, 256] 0 False ______________________________________________________________________ Total params: 41,405,588 Total trainable params: 20,137,940 Total non-trainable params: 21,267,648 Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99) Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ Loss function : FeatureLoss ====================================================================== Callbacks functions applied